import base64
import functools
import hashlib
import json
import os
from copy import copy
from io import BytesIO
from logging import getLogger
from typing import Any, cast

# SDK Docs: https://googleapis.github.io/python-genai/
import anyio
from google.genai import Client
from google.genai.errors import APIError, ClientError
from google.genai.types import (
    Candidate,
    CodeExecutionResult,
    Content,
    ContentListUnion,
    ContentListUnionDict,
    ExecutableCode,
    File,
    FinishReason,
    FunctionCallingConfig,
    FunctionCallingConfigMode,
    FunctionDeclaration,
    FunctionResponse,
    GenerateContentConfig,
    GenerateContentResponse,
    GenerateContentResponsePromptFeedback,
    GenerateContentResponseUsageMetadata,
    GoogleSearch,
    HarmBlockThreshold,
    HarmCategory,
    HttpOptions,
    Image,
    Language,
    Outcome,
    Part,
    SafetySetting,
    SafetySettingDict,
    Schema,
    ThinkingConfig,
    ThinkingLevel,
    Tool,
    ToolCodeExecution,
    ToolConfig,
    ToolListUnion,
    Type,
)
from pydantic import JsonValue
from shortuuid import uuid
from typing_extensions import override

from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
from inspect_ai._util.content import (
    Content as InspectContent,
)
from inspect_ai._util.content import (
    ContentAudio,
    ContentData,
    ContentDocument,
    ContentImage,
    ContentReasoning,
    ContentText,
    ContentToolUse,
    ContentVideo,
)
from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.http import is_retryable_http_status
from inspect_ai._util.images import file_as_data
from inspect_ai._util.kvstore import inspect_kvstore
from inspect_ai._util.logger import warn_once
from inspect_ai._util.trace import trace_message
from inspect_ai.model import (
    ChatCompletionChoice,
    ChatMessage,
    ChatMessageAssistant,
    ChatMessageTool,
    ChatMessageUser,
    GenerateConfig,
    Logprob,
    Logprobs,
    ModelAPI,
    ModelOutput,
    ModelUsage,
    StopReason,
    TopLogprob,
)
from inspect_ai.model._generate_config import normalized_batch_config
from inspect_ai.model._model import log_model_retry
from inspect_ai.model._model_call import ModelCall
from inspect_ai.model._providers._google_batch import GoogleBatcher
from inspect_ai.model._providers._google_citations import (
    distribute_citations_to_text_parts,
    get_candidate_citations,
)
from inspect_ai.model._retry import model_retry_config
from inspect_ai.tool import (
    ToolCall,
    ToolChoice,
    ToolFunction,
    ToolInfo,
    ToolParam,
    ToolParams,
)

from .util import model_base_url
from .util.hooks import HttpHooks, HttpxHooks

logger = getLogger(__name__)


# environment variables consulted for API keys
GOOGLE_API_KEY = "GOOGLE_API_KEY"
VERTEX_API_KEY = "VERTEX_API_KEY"

# model_args key under which users may pass custom safety settings
SAFETY_SETTINGS = "safety_settings"
# default to no blocking in any harm category (evals generally require
# unfiltered model behavior; users can override individual categories via
# the 'safety_settings' model arg)
DEFAULT_SAFETY_SETTINGS: list[SafetySettingDict] = [
    {
        "category": HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
        "threshold": HarmBlockThreshold.BLOCK_NONE,
    },
    {
        "category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        "threshold": HarmBlockThreshold.BLOCK_NONE,
    },
    {
        "category": HarmCategory.HARM_CATEGORY_HARASSMENT,
        "threshold": HarmBlockThreshold.BLOCK_NONE,
    },
    {
        "category": HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        "threshold": HarmBlockThreshold.BLOCK_NONE,
    },
    {
        "category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        "threshold": HarmBlockThreshold.BLOCK_NONE,
    },
]


class GoogleGenAIAPI(ModelAPI):
    def __init__(
        self,
        model_name: str,
        base_url: str | None,
        api_key: str | None,
        config: GenerateConfig = GenerateConfig(),
        api_version: str | None = None,
        **model_args: Any,
    ) -> None:
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            api_key_vars=[GOOGLE_API_KEY, VERTEX_API_KEY],
            config=config,
        )

        # record api version
        self.api_version = api_version

        # pick out user-provided safety settings and merge against default
        self.safety_settings: list[SafetySettingDict] = DEFAULT_SAFETY_SETTINGS.copy()
        if SAFETY_SETTINGS in model_args:

            def update_safety_setting(
                category: HarmCategory, threshold: HarmBlockThreshold
            ) -> None:
                for setting in self.safety_settings:
                    if setting["category"] == category:
                        setting["threshold"] = threshold
                        break

            user_safety_settings = parse_safety_settings(
                model_args.get(SAFETY_SETTINGS)
            )
            for safety_setting in user_safety_settings:
                if safety_setting["category"] and safety_setting["threshold"]:
                    update_safety_setting(
                        safety_setting["category"], safety_setting["threshold"]
                    )

            del model_args[SAFETY_SETTINGS]

        # extract any service prefix from model name
        parts = model_name.split("/")
        if len(parts) > 1:
            self.service: str | None = parts[0]
        else:
            self.service = None

        # vertex can also be forced by the GOOGLE_GENAI_USE_VERTEX_AI flag
        if self.service is None:
            if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true":
                self.service = "vertex"

        # ensure we haven't specified an invalid service
        if self.service is not None and self.service != "vertex":
            raise RuntimeError(
                f"Invalid service name for google: {self.service}. "
                + "Currently 'vertex' is the only supported service."
            )

        # handle auth (vertex or standard google api key)
        if self.is_vertex():
            # see if we are running in express mode (propagate api key if we are)
            # https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
            vertex_api_key = os.environ.get(VERTEX_API_KEY, None)
            if vertex_api_key and not self.api_key:
                self.api_key = vertex_api_key

            # When not using express mode the GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION
            # environment variables should be set, OR the 'project' and 'location' should be
            # passed within the model_args.
            # https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2
            if not vertex_api_key:
                if not os.environ.get(
                    "GOOGLE_CLOUD_PROJECT", None
                ) and not model_args.get("project", None):
                    raise PrerequisiteError(
                        "Google provider requires either the GOOGLE_CLOUD_PROJECT environment variable "
                        + "or the 'project' custom model arg (-M) when running against vertex."
                    )
                if not os.environ.get(
                    "GOOGLE_CLOUD_LOCATION", None
                ) and not model_args.get("location", None):
                    raise PrerequisiteError(
                        "Google provider requires either the GOOGLE_CLOUD_LOCATION environment variable "
                        + "or the 'location' custom model arg (-M) when running against vertex."
                    )

            # custom base_url
            self.base_url = model_base_url(
                self.base_url, ["GOOGLE_VERTEX_BASE_URL", "VERTEX_BASE_URL"]
            )

        # normal google endpoint
        else:
            # read api key from env
            if not self.api_key:
                self.api_key = os.environ.get(GOOGLE_API_KEY, None)

            # custom base_url
            self.base_url = model_base_url(self.base_url, "GOOGLE_BASE_URL")

        # save model args
        self.model_args = model_args

        # initialize batcher
        self._batcher: GoogleBatcher | None = None

    def is_vertex(self) -> bool:
        return self.service == "vertex"

    async def generate(
        self,
        input: list[ChatMessage],
        tools: list[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
        # http options
        http_options = HttpOptions(
            base_url=self.base_url,
            api_version=self.api_version,
        )

        # apply timeout if specified
        if config.timeout:
            http_options.timeout = config.timeout * 1000

        # resolve batcher as required
        self._resolve_batcher(config, http_options)

        # create client and manage its lifetime to this call
        client = Client(
            vertexai=self.is_vertex(),
            api_key=self.api_key,
            http_options=http_options,
            **self.model_args,
        )
        async with client.aio:
            # create hooks and allocate request
            http_hooks = HttpxHooks(client._api_client._async_httpx_client)
            request_id = http_hooks.start_request()

            # Create google-genai types.
            gemini_contents = await as_chat_messages(client, input)
            has_native_tools, gemini_tools = (
                self.chat_tools(tools) if len(tools) > 0 else (False, None)
            )
            gemini_tool_config = (
                chat_tool_config(tool_choice)
                if not has_native_tools and len(tools) > 0
                else None
            )
            system_instruction = await extract_system_message_as_parts(client, input)
            parameters = GenerateContentConfig(
                http_options=HttpOptions(
                    headers={HttpHooks.REQUEST_ID_HEADER: request_id}
                ),
                temperature=config.temperature,
                top_p=config.top_p,
                top_k=config.top_k,
                max_output_tokens=config.max_tokens,
                stop_sequences=config.stop_seqs,
                candidate_count=config.num_choices,
                presence_penalty=config.presence_penalty,
                frequency_penalty=config.frequency_penalty,
                response_logprobs=config.logprobs,
                logprobs=config.top_logprobs,
                safety_settings=safety_settings_to_list(self.safety_settings),
                tools=gemini_tools,
                tool_config=gemini_tool_config,
                system_instruction=system_instruction,  # type: ignore[arg-type]
                thinking_config=self.chat_thinking_config(config),
            )
            if config.response_schema is not None:
                parameters.response_mime_type = "application/json"
                parameters.response_schema = schema_from_param(
                    config.response_schema.json_schema, nullable=None
                )

            response: GenerateContentResponse | None = None

            def model_call() -> ModelCall:
                return build_model_call(
                    contents=gemini_contents,  # type: ignore[arg-type]
                    safety_settings=self.safety_settings,
                    generation_config=parameters,
                    tools=gemini_tools,
                    tool_config=gemini_tool_config,
                    system_instruction=system_instruction,
                    response=response,
                    time=http_hooks.end_request(request_id),
                )

            try:
                response = await (
                    self._batcher.generate_for_request(
                        {
                            "contents": [
                                content.model_dump(exclude_none=True)
                                for content in gemini_contents
                            ],
                            **parameters.model_dump(exclude_none=True),
                        }
                    )
                    if self._batcher
                    else client.aio.models.generate_content(
                        model=self.service_model_name(),
                        contents=gemini_contents,  # type: ignore[arg-type]
                        config=parameters,
                    )
                )
            except ClientError as ex:
                return self.handle_client_error(ex), model_call()

            model_name = response.model_version or self.service_model_name()
            output = ModelOutput(
                model=model_name,
                choices=completion_choices_from_candidates(model_name, response),
                usage=usage_metadata_to_model_usage(response.usage_metadata),
            )

            return output, model_call()

    def service_model_name(self) -> str:
        """Model name without any service prefix."""
        return self.model_name.replace(f"{self.service}/", "", 1)

    def canonical_name(self) -> str:
        return self.service_model_name()

    def is_gemini(self) -> bool:
        return "gemini-" in self.service_model_name()

    def is_gemini_1_5(self) -> bool:
        return "gemini-1.5" in self.service_model_name()

    def is_gemini_2_0(self) -> bool:
        return "gemini-2.0" in self.service_model_name()

    def is_gemini_2_5(self) -> bool:
        return "gemini-2.5" in self.service_model_name()

    def is_gemini_3(self) -> bool:
        return "gemini-3" in self.service_model_name()

    def is_gemini_3_plus(self) -> bool:
        return (
            self.is_gemini()
            and not self.is_gemini_1_5()
            and not self.is_gemini_2_0()
            and not self.is_gemini_2_5()
        )

    def is_gemini_thinking_only(self) -> bool:
        return (
            self.is_gemini_2_5() or self.is_gemini_3()
        ) and "-pro" in self.service_model_name()

    @override
    def emulate_reasoning_history(self) -> bool:
        # older gemini models don't know about reasoning
        return self.is_gemini_1_5() or self.is_gemini_2_0()

    @override
    def should_retry(self, ex: BaseException) -> bool:
        if isinstance(ex, APIError) and ex.code is not None:
            return is_retryable_http_status(ex.code)
        else:
            return False

    @override
    def connection_key(self) -> str:
        """Scope for enforcing max_connections."""
        return str(self.api_key)

    @override
    def is_auth_failure(self, ex: Exception) -> bool:
        if isinstance(ex, APIError):
            return ex.code == 401
        return False

    def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
        if (
            ex.code == 400
            and ex.message
            and (
                "maximum number of tokens" in ex.message
                or "size exceeds the limit" in ex.message
            )
        ):
            return ModelOutput.from_content(
                self.service_model_name(),
                content=ex.message,
                stop_reason="model_length",
            )
        else:
            raise ex

    def chat_thinking_config(self, config: GenerateConfig) -> ThinkingConfig | None:
        # thinking_config is only supported for gemini 2.5 above
        has_thinking_config = (
            self.is_gemini() and not self.is_gemini_1_5() and not self.is_gemini_2_0()
        )
        if has_thinking_config:
            # user is attempting to turn off reasoning, this only works for some models
            # so we warn for those models where it can't be done.
            if config.reasoning_tokens == 0 or config.reasoning_effort == "none":
                if self.is_gemini_thinking_only():
                    # When reasoning_tokens is set to 0 and it's a thinking only model we don't
                    # bother trying to shut down thinking as this is not possible:
                    #   https://ai.google.dev/gemini-api/docs/thinking#set-budget
                    # warn and return include_thoughts=True so the user sees what is happening
                    warn_once(
                        logger,
                        f"Thinking cannot be disabled for model {self.service_model_name()}.",
                    )
                    return ThinkingConfig(include_thoughts=True)
                else:
                    # otherwise do the disable
                    return ThinkingConfig(include_thoughts=False, thinking_budget=0)

            # thinking_level is now the preferred way of setting reasoning (thinking_budget is deprecated)
            # consult it first for gemini 3+ models, otherwise fall through to tokens for other models
            elif config.reasoning_effort is not None and self.is_gemini_3_plus():
                match config.reasoning_effort:
                    case "minimal" | "low":
                        thinking_level: ThinkingLevel | None = ThinkingLevel.LOW
                    case "medium" | "high":  # note: 'medium' thinking level coming soon
                        thinking_level = ThinkingLevel.HIGH
                    case _:
                        thinking_level = None  # can't happen, keep mypy happy
                return ThinkingConfig(
                    include_thoughts=True, thinking_level=thinking_level
                )

            # enable thinking_budget if specified
            elif config.reasoning_tokens is not None:
                return ThinkingConfig(
                    include_thoughts=True, thinking_budget=config.reasoning_tokens
                )

            # generic thinking with defaults
            else:
                return ThinkingConfig(include_thoughts=True)
        else:
            return None

    def _use_native_search(self, tool: ToolInfo) -> bool:
        return (
            tool.name == "web_search"
            and tool.options is not None
            and "gemini" in tool.options
            # Support "starts with" Gemini 2.0
            and (self.is_gemini() and not self.is_gemini_1_5())
        )

    def _use_native_code_execution(self, tool: ToolInfo) -> bool:
        return (
            tool.name == "code_execution"
            and tool.options is not None
            and "google" in tool.options.get("providers", {})
            # Support "starts with" Gemini 2.0
            and (self.is_gemini() and not self.is_gemini_1_5())
        )

    def _categorize_tool(
        self,
        acc: tuple[
            GoogleSearch | None, ToolCodeExecution | None, list[FunctionDeclaration]
        ],
        tool: ToolInfo,
    ) -> tuple[
        GoogleSearch | None, ToolCodeExecution | None, list[FunctionDeclaration]
    ]:
        """Reducer function that categorizes tools into native search vs function declarations.

        Returns:
            Tuple of (has_native_search, function_declarations) where has_native_search
            is True if any tool uses native search, and function_declarations contains
            all non-native-search tools converted to FunctionDeclaration objects.
        """
        return (
            (self._google_search_options(tool.options), acc[1], acc[2])
            if tool.options and self._use_native_search(tool)
            else (acc[0], ToolCodeExecution(), acc[2])
            if tool.options and self._use_native_code_execution(tool)
            else (
                acc[0],
                acc[1],
                acc[2]
                + [
                    FunctionDeclaration(
                        name=tool.name,
                        description=tool.description,
                        parameters=schema_from_param(tool.parameters)
                        if len(tool.parameters.properties) > 0
                        else None,
                    )
                ],
            )
        )

    def _google_search_options(self, options: dict[str, Any]) -> GoogleSearch:
        gemini_options = options.get("gemini", None)
        if isinstance(gemini_options, dict):
            return GoogleSearch.model_validate(gemini_options)
        else:
            return GoogleSearch()

    def chat_tools(self, tools: list[ToolInfo]) -> tuple[bool, ToolListUnion]:
        # cleave up tools (must use either native tools or client tools but not both)
        search_seed: GoogleSearch | None = None
        execution_seed: ToolCodeExecution | None = None
        google_search, code_execution, function_declarations = functools.reduce(
            self._categorize_tool,
            tools,
            (search_seed, execution_seed, list[FunctionDeclaration]()),
        )

        # native tools
        if google_search or code_execution:
            if function_declarations:
                raise ValueError(
                    "Gemini does not yet support native web search or code execution concurrently with other tools."
                )
            native_tools: ToolListUnion = []
            if google_search:
                native_tools.append(Tool(google_search=google_search))
            if code_execution:
                native_tools.append(Tool(code_execution=code_execution))
            return (True, native_tools)

        # client tools
        else:
            return (False, [Tool(function_declarations=function_declarations)])

    def _resolve_batcher(
        self, config: GenerateConfig, http_options: HttpOptions
    ) -> None:
        if self._batcher or not (batch_config := normalized_batch_config(config.batch)):
            return

        # create a dedicated client instance for the batcher
        client = Client(
            vertexai=self.is_vertex(),
            api_key=self.api_key,
            http_options=http_options,
            **self.model_args,
        )

        self._batcher = GoogleBatcher(
            client,
            batch_config,
            model_retry_config(
                self.model_name,
                config.max_retries,
                config.timeout,
                self.should_retry,
                lambda ex: None,
                log_model_retry,
            ),
            self.service_model_name(),
        )


def safety_settings_to_list(
    safety_settings: list[SafetySettingDict],
) -> list[SafetySetting]:
    """Convert safety-setting dicts into SafetySetting model instances."""
    return [
        SafetySetting(category=entry["category"], threshold=entry["threshold"])
        for entry in safety_settings
    ]


def build_model_call(
    contents: ContentListUnion | ContentListUnionDict,
    generation_config: GenerateContentConfig,
    safety_settings: list[SafetySettingDict],
    tools: ToolListUnion | None,
    tool_config: ToolConfig | None,
    system_instruction: list[File | Part | Image | str] | None,
    response: GenerateContentResponse | None,
    time: float | None,
) -> ModelCall:
    """Assemble a ModelCall record of a request/response for logging.

    Args:
        contents: Chat contents sent to the model.
        generation_config: Full generation config. The safety settings,
            tools, tool config, and system instruction are nulled out of
            the logged copy since they appear as separate top-level fields
            in the actual HTTP request body.
        safety_settings: Safety settings applied to the request.
        tools: Tools included with the request (or None).
        tool_config: Tool choice configuration (or None).
        system_instruction: System instruction parts (or None).
        response: Model response (None if the request failed).
        time: Elapsed request time (or None).

    Returns:
        ModelCall with base64 payloads redacted via model_call_filter.
    """
    return ModelCall.create(
        request=dict(
            contents=contents,
            # the excluded fields are passed to the Python API as part of
            # GenerateContentConfig however they are passed separately in
            # the actual http request body, so reflect that here
            generation_config=generation_config.model_copy(
                update={
                    "safety_settings": None,
                    "tools": None,
                    "tool_config": None,
                    "system_instruction": None,
                }
            ),
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
            system_instruction=system_instruction,
        ),
        response=response if response is not None else {},
        filter=model_call_filter,
        time=time,
    )


def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
    """Redact inline base64 payloads from logged model call values."""
    if key == "inline_data" and isinstance(value, dict) and "data" in value:
        redacted = dict(value)
        redacted["data"] = BASE_64_DATA_REMOVED
        return redacted
    return value


async def as_chat_messages(
    client: Client, messages: list[ChatMessage]
) -> list[Content]:
    """Convert inspect chat messages into google-genai Content objects.

    System messages are excluded (there is no "system" role in the
    `google-genai` package -- they travel separately as the
    `system_instruction` on GenerateContentConfig), and runs of
    consecutive tool messages are merged into a single Content.
    """
    # system messages are handled via system_instruction, not contents
    non_system = [m for m in messages if m.role != "system"]

    # convert each remaining message to google-genai Content
    converted = [await content(client, m) for m in non_system]

    # merge consecutive tool messages into single Content objects
    merged: list[Content] = []
    for item in converted:
        merged = consecutive_tool_message_reducer(merged, item)

    return merged


def consecutive_tool_message_reducer(
    messages: list[Content],
    message: Content,
) -> list[Content]:
    """Reducer that merges runs of consecutive tool-response messages.

    When both the incoming message and the last accumulated message carry
    function responses, their parts are concatenated into a single user
    Content; otherwise the message is appended as-is.
    """
    if messages and is_tool_message(messages[-1]) and is_tool_message(message):
        combined = (messages[-1].parts or []) + (message.parts or [])
        messages[-1] = Content(role="user", parts=combined)
    else:
        messages.append(message)
    return messages


def is_tool_message(message: Content) -> bool:
    """True when message is a user message whose first part is a function response."""
    if message.role != "user":
        return False
    parts = message.parts
    if not parts:
        return False
    return parts[0].function_response is not None


async def content(
    client: Client,
    message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
) -> Content:
    """Convert an inspect chat message to a google-genai Content.

    User messages become 'user' content; assistant messages become 'model'
    content (including tool calls, with encrypted reasoning replayed as a
    thought_signature on the following part or first tool call); tool
    messages become 'user' content carrying a function response.

    Args:
        client: Client used for uploading media content as files.
        message: The message to convert.

    Returns:
        Content for the message.
    """
    # holds an encrypted reasoning block whose signature must be applied
    # to the next emitted part (or first tool call)
    working_reasoning_block: ContentReasoning | None = None
    if isinstance(message, ChatMessageUser):
        if isinstance(message.content, str):
            return Content(
                role="user", parts=[await content_part(client, message.content)]
            )
        return Content(
            role="user",
            parts=(
                [await content_part(client, content) for content in message.content]
            ),
        )
    elif isinstance(message, ChatMessageAssistant):
        content_parts: list[Part] = []

        if isinstance(message.content, str):
            content_parts.append(Part(text=message.content or NO_CONTENT))
        else:
            for content in message.content:
                if isinstance(content, ContentReasoning):
                    # if this is encrypted reasoning, save it for applying the thought_signature
                    # to the next part (don't emit a separate thought part during replay)
                    if content.redacted:
                        working_reasoning_block = content
                    else:
                        # unencrypted reasoning (for older models or debugging)
                        content_parts.append(Part(text=content.reasoning, thought=True))

                else:
                    # server side tool use
                    if isinstance(content, ContentToolUse):
                        parts_to_append = parts_from_server_tool_use(content)

                    # other content
                    else:
                        parts_to_append = [await content_part(client, content)]

                    # If previously there was a reasoning block, we need to set the "thought_signature"
                    # using the reasoning from that block.
                    # However, if there are tool calls in this message, the signature should go on
                    # the first tool call instead, not on text or server tool use parts
                    # (per Gemini API docs).
                    if (
                        working_reasoning_block is not None
                        and message.tool_calls is None
                    ):
                        if (
                            working_reasoning_block.reasoning is not None
                            and working_reasoning_block.redacted
                        ):
                            parts_to_append[0].thought_signature = base64.b64decode(
                                working_reasoning_block.reasoning.encode()
                            )
                        else:
                            logger.warning(
                                "Reasoning block must have a reasoning signature to set thought_signature."
                            )
                        # Now, reset the previous reasoning block.
                        working_reasoning_block = None
                    content_parts.extend(parts_to_append)

        # Now handle tool calls
        if message.tool_calls is not None:
            # Per Gemini API docs: thought_signature goes on the first tool call in a message.
            # For parallel function calls, only the first FC gets the signature.
            # For sequential function calls (multi-step), each step is a separate message,
            # so each will have its own reasoning block and signature.
            # The loop below applies the signature to the first tool call (when working_reasoning_block
            # is not None), then clears it so subsequent tool calls don't get it.
            for tool_call in message.tool_calls:
                # extract the part
                part = Part.from_function_call(
                    name=tool_call.function,
                    args=tool_call.arguments,
                )

                # handle reasoning block if available
                if working_reasoning_block is not None:
                    # tool call reasoning should always use a thought_signature
                    if (
                        working_reasoning_block.reasoning is not None
                        and working_reasoning_block.redacted
                    ):
                        part.thought_signature = base64.b64decode(
                            working_reasoning_block.reasoning.encode()
                        )
                    else:
                        logger.warning(
                            "Reasoning block must have a reasoning signature to set thought_signature."
                        )
                    working_reasoning_block = None

                content_parts.append(part)
        return Content(role="model", parts=content_parts)

    elif isinstance(message, ChatMessageTool):
        # tool results are sent back as a user message w/ function response
        response = FunctionResponse(
            name=message.function,
            response={
                "content": (
                    message.error.message if message.error is not None else message.text
                )
            },
        )
        return Content(role="user", parts=[Part(function_response=response)])


async def content_part(client: Client, content: InspectContent | str) -> Part:
    """Convert a single inspect content item (or raw string) to a Part.

    Raises:
        RuntimeError: For content kinds the Google provider never receives
            here (reasoning, data, tool use).
    """
    if isinstance(content, str):
        return Part.from_text(text=content or NO_CONTENT)
    if isinstance(content, ContentText):
        return Part.from_text(text=content.text or NO_CONTENT)
    if isinstance(content, ContentReasoning):
        raise RuntimeError("content_part should never encounter ContentReasoning")
    if isinstance(content, ContentData):
        raise RuntimeError("Google provider should never encounter ContentData")
    if isinstance(content, ContentToolUse):
        raise RuntimeError("Google provider should never encounter ContentToolUse")
    # remaining kinds are media (image/audio/video/document)
    return await chat_content_to_part(client, content)


async def chat_content_to_part(
    client: Client,
    content: ContentImage | ContentAudio | ContentVideo | ContentDocument,
) -> Part:
    """Convert media content to a Part (inline bytes for images, file uri otherwise)."""
    # images are sent inline as bytes
    if isinstance(content, ContentImage):
        image_bytes, image_mime = await file_as_data(content.image)
        return Part.from_bytes(mime_type=image_mime, data=image_bytes)

    # audio/video/documents are uploaded and referenced by uri
    uploaded = await file_for_content(client, content)
    if uploaded.uri is None:
        raise RuntimeError(f"Failed to get URI for file: {uploaded.display_name}")
    return Part.from_uri(file_uri=uploaded.uri, mime_type=uploaded.mime_type)


async def extract_system_message_as_parts(
    client: Client,
    messages: list[ChatMessage],
) -> list[File | Part | Image | str] | None:
    """Collect the content of all system messages as system-instruction parts.

    Args:
        client: google-genai client (used for media uploads within content).
        messages: Full chat history; only messages with role "system" are read.

    Returns:
        List of parts for the system instruction, or None when there are no
        system messages (google-genai raises "ValueError: content is required."
        if given an empty list).

    Raises:
        ValueError: if a system message's content is neither str nor list.
    """
    system_parts: list[File | Part | Image | str] = []
    for message in messages:
        if message.role == "system":
            content = message.content
            if isinstance(content, str):
                system_parts.append(Part.from_text(text=content))
            elif isinstance(content, list):  # list[InspectContent]
                # note: use a distinct loop variable ('item') rather than
                # shadowing the iterable being traversed
                system_parts.extend(
                    [await content_part(client, item) for item in content]
                )
            else:
                raise ValueError(f"Unsupported system message content: {content}")
    # google-genai raises "ValueError: content is required." if the list is empty.
    return system_parts or None


# https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
def schema_from_param(
    param: ToolParam | ToolParams,
    nullable: bool | None = False,
    description: str | None = None,
) -> Schema:
    if isinstance(param, ToolParams):
        param = ToolParam(
            type=param.type, properties=param.properties, required=param.required
        )

    # use fallback description if the param doesn't have its own
    param_description = param.description or description

    if param.type == "number":
        return Schema(
            type=Type.NUMBER, description=param_description, nullable=nullable
        )
    elif param.type == "integer":
        return Schema(
            type=Type.INTEGER, description=param_description, nullable=nullable
        )
    elif param.type == "boolean":
        return Schema(
            type=Type.BOOLEAN, description=param_description, nullable=nullable
        )
    elif param.type == "string":
        if param.format == "date-time":
            return Schema(
                type=Type.STRING,
                description=param_description,
                format="^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$",
                nullable=nullable,
            )
        elif param.format == "date":
            return Schema(
                type=Type.STRING,
                description=param_description,
                format="^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
                nullable=nullable,
            )
        elif param.format == "time":
            return Schema(
                type=Type.STRING,
                description=param_description,
                format="^[0-9]{2}:[0-9]{2}:[0-9]{2}$",
                nullable=nullable,
            )
        return Schema(
            type=Type.STRING, description=param_description, nullable=nullable
        )
    elif param.type == "array":
        return Schema(
            type=Type.ARRAY,
            description=param_description,
            items=schema_from_param(param.items) if param.items else None,
            nullable=nullable,
        )
    elif param.type == "object":
        return Schema(
            type=Type.OBJECT,
            description=param_description,
            properties={k: schema_from_param(v) for k, v in param.properties.items()}
            if param.properties is not None
            else {},
            required=param.required,
            nullable=nullable,
        )
    # convert unions to optional params if the second type is 'null'
    elif param.anyOf:
        if len(param.anyOf) == 2 and param.anyOf[1].type == "null":
            return schema_from_param(
                param.anyOf[0], nullable=True, description=param_description
            )
        else:
            return Schema(type=Type.TYPE_UNSPECIFIED, description=param_description)
    elif param.enum:
        return Schema(
            type=Type.STRING,
            format="enum",
            enum=param.enum,
            description=param_description,
        )
    else:
        return Schema(type=Type.TYPE_UNSPECIFIED, description=param_description)


def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
    """Map an Inspect tool choice onto a google-genai ToolConfig."""
    if isinstance(tool_choice, ToolFunction):
        # force a call to one specific named function
        config = FunctionCallingConfig(
            mode=FunctionCallingConfigMode.ANY,
            allowed_function_names=[tool_choice.name],
        )
    else:
        # string choices ("auto" | "any" | "none") map directly to SDK modes
        config = FunctionCallingConfig(
            mode=cast(FunctionCallingConfigMode, tool_choice.upper())
        )
    return ToolConfig(function_calling_config=config)


def completion_choice_from_candidate(
    model: str, candidate: Candidate
) -> ChatCompletionChoice:
    """Convert a google-genai Candidate into an Inspect ChatCompletionChoice.

    Google distributes reasoning text and thought signatures across multiple
    content parts; these are consolidated here into ContentReasoning blocks to
    match the Inspect schema (and unrolled back into parts on replay).

    Args:
        model: Model name to record on the assistant message.
        candidate: Candidate returned by the Gemini API.

    Returns:
        ChatCompletionChoice with content, tool calls, stop reason, and
        (when provided) logprobs.
    """
    # content we'll return
    content: list[
        ContentText
        | ContentReasoning
        | ContentImage
        | ContentToolUse
        | ContentAudio
        | ContentVideo
        | ContentData
        | ContentDocument
    ] = []

    # google distributes reasoning text and thought_signature across multiple
    # content parts -- we need to consolidate this into a single ContentReasoning
    # to match our schema (we'll unroll it back into parts on replay)
    working_reasoning_block: ContentReasoning | None = None

    # content can be None when the finish_reason is SAFETY
    # content.parts can be None when the finish_reason is MALFORMED_FUNCTION_CALL
    if candidate.content is not None and candidate.content.parts is not None:
        # traverse parts
        parts = candidate.content.parts
        for i, part in enumerate(parts):
            if part.text is None and part.executable_code is None:
                continue  # we only care about text and executable_code here

            if part.code_execution_result is not None:
                continue  # we pick up code execution results with part.executable_code

            if part.text is not None and part.thought is True:
                # create and append a reasoning block, saving a reference to it
                # so that we can amend it with a thought signature if/when one
                # arrives later in the stream (note that multiple reasoning
                # parts without a signature can occur)
                working_reasoning_block = ContentReasoning(
                    reasoning=part.text,
                    redacted=False,
                )
                content.append(working_reasoning_block)
            else:
                # check whether this part has an associated thought_signature
                # and whether it corresponds to the previous reasoning block
                if part.thought_signature is not None:
                    if working_reasoning_block is None:
                        # no prior reasoning block -- record the signature as
                        # a redacted reasoning block of its own
                        content.append(
                            ContentReasoning(
                                reasoning=base64.b64encode(
                                    part.thought_signature
                                ).decode(),
                                redacted=True,
                            )
                        )
                    else:
                        # attach the thought_signature to the previous
                        # reasoning block (moving its text into the summary)
                        working_reasoning_block.summary = (
                            working_reasoning_block.reasoning
                        )
                        working_reasoning_block.reasoning = base64.b64encode(
                            part.thought_signature
                        ).decode()
                        working_reasoning_block.redacted = True
                        # clear it out
                        working_reasoning_block = None

                if part.text is not None:
                    content.append(ContentText(text=part.text))
                if part.executable_code is not None:
                    # lookahead for execution result
                    code_execution_result = (
                        parts[i + 1].code_execution_result
                        if i + 1 < len(parts)
                        else None
                    )
                    # append tool use
                    content.append(
                        server_tool_use_from_executable_code(
                            part.executable_code, code_execution_result
                        )
                    )

    # distribute citations to individual ContentText parts with adjusted indexes
    citations = get_candidate_citations(candidate)
    if citations:
        distribute_citations_to_text_parts(content, citations)

    # now tool calls
    tool_calls: list[ToolCall] = []
    if candidate.content is not None and candidate.content.parts is not None:
        for part in candidate.content.parts:
            if part.function_call:
                if (
                    part.function_call.name is None
                    or part.function_call.args is None
                ):
                    raise ValueError(f"Incomplete function call: {part.function_call}")

                # if the part has a thought_signature, try to associate it
                # with the previous working reasoning block
                if part.thought_signature:
                    if working_reasoning_block is None:
                        # we assume tool calls don't have independent reasoning
                        # blocks unless they are preceded by a reasoning block
                        reasoning_block = ContentReasoning(
                            reasoning=base64.b64encode(part.thought_signature).decode(),
                            redacted=True,
                        )

                        content.append(reasoning_block)
                    else:
                        # attach the thought_signature to the previous reasoning block
                        working_reasoning_block.summary = (
                            working_reasoning_block.reasoning
                        )
                        working_reasoning_block.reasoning = base64.b64encode(
                            part.thought_signature
                        ).decode()
                        working_reasoning_block.redacted = True
                        working_reasoning_block = None

                tool_calls.append(
                    ToolCall(
                        id=f"{part.function_call.name}_{uuid()}",
                        function=part.function_call.name,
                        arguments=part.function_call.args,
                    )
                )

    # stop reason
    stop_reason = finish_reason_to_stop_reason(
        candidate.finish_reason or FinishReason.STOP
    )

    # build choice
    choice = ChatCompletionChoice(
        message=ChatMessageAssistant(
            content=content if len(content) > 0 else "",
            tool_calls=tool_calls if len(tool_calls) > 0 else None,
            model=model,
            source="generate",
        ),
        stop_reason=stop_reason,
    )

    # add logprobs if provided (compare against None explicitly -- a
    # log_probability of 0.0 is a legitimate value for a certain token
    # and must not be dropped by a truthiness check)
    if candidate.logprobs_result:
        logprobs: list[Logprob] = []
        if (
            candidate.logprobs_result.chosen_candidates
            and candidate.logprobs_result.top_candidates
        ):
            for chosen, top in zip(
                candidate.logprobs_result.chosen_candidates,
                candidate.logprobs_result.top_candidates,
            ):
                if chosen.token is not None and chosen.log_probability is not None:
                    logprobs.append(
                        Logprob(
                            token=chosen.token,
                            logprob=chosen.log_probability,
                            top_logprobs=[
                                TopLogprob(token=c.token, logprob=c.log_probability)
                                for c in (top.candidates or [])
                                if c.token is not None
                                and c.log_probability is not None
                            ],
                        )
                    )
            choice.logprobs = Logprobs(content=logprobs)

    return choice


def completion_choices_from_candidates(
    model: str,
    response: GenerateContentResponse,
) -> list[ChatCompletionChoice]:
    """Convert all candidates in a response to completion choices.

    Falls back to a content-filter choice when the prompt was blocked (no
    candidates, prompt_feedback present) and to an empty "stop" choice when
    the response has neither candidates nor feedback.
    """
    if response.candidates:
        # present choices in candidate-index order
        ordered = sorted(response.candidates, key=lambda c: c.index or 0)
        return [completion_choice_from_candidate(model, c) for c in ordered]

    if response.prompt_feedback:
        # the prompt itself was blocked -- surface the feedback as content
        blocked_message = ChatMessageAssistant(
            content=prompt_feedback_to_content(response.prompt_feedback),
            model=model,
            source="generate",
        )
        return [
            ChatCompletionChoice(
                message=blocked_message, stop_reason="content_filter"
            )
        ]

    # no candidates and no feedback: synthesize an empty completion
    empty_message = ChatMessageAssistant(
        content=NO_CONTENT, model=model, source="generate"
    )
    return [ChatCompletionChoice(message=empty_message, stop_reason="stop")]


def prompt_feedback_to_content(
    feedback: GenerateContentResponsePromptFeedback,
) -> str:
    """Render prompt feedback (block reason + safety ratings) as display text."""
    lines: list[str] = [
        f"BLOCKED: {str(feedback.block_reason) if feedback.block_reason else 'UNKNOWN'}"
    ]
    if feedback.block_reason_message is not None:
        lines.append(feedback.block_reason_message)
    # include each safety rating as pretty-printed JSON
    for rating in feedback.safety_ratings or []:
        lines.append(rating.model_dump_json(indent=2))
    return "\n".join(lines)


def usage_metadata_to_model_usage(
    metadata: GenerateContentResponseUsageMetadata | None,
) -> ModelUsage | None:
    """Map google-genai usage metadata to ModelUsage (None passes through)."""
    if metadata is None:
        return None

    def count(tokens: int | None) -> int:
        # the SDK reports absent counts as None; normalize to 0
        return tokens or 0

    return ModelUsage(
        input_tokens=count(metadata.prompt_token_count),
        output_tokens=count(metadata.candidates_token_count),
        total_tokens=count(metadata.total_token_count),
        reasoning_tokens=count(metadata.thoughts_token_count),
    )


def server_tool_use_from_executable_code(
    executable_code: ExecutableCode, result: CodeExecutionResult | None
) -> ContentToolUse:
    """Build a code-execution ContentToolUse from a part and optional result."""
    # default to no output / no error (result part may be absent)
    result_output = ""
    result_error: str | None = None
    if result is not None:
        result_output = result.output or ""
        # any non-OK outcome is surfaced as the error string
        if result.outcome is not None and result.outcome != Outcome.OUTCOME_OK:
            result_error = result.outcome

    return ContentToolUse(
        tool_type="code_execution",
        id="",
        name=executable_code.language or Language.LANGUAGE_UNSPECIFIED,
        arguments=executable_code.code or "",
        result=result_output,
        error=result_error,
    )


def parts_from_server_tool_use(tool: ContentToolUse) -> list[Part]:
    """Unroll a code-execution ContentToolUse back into google-genai parts."""
    code_part = Part.from_executable_code(
        code=tool.arguments, language=Language(tool.name)
    )
    # no result and no error -> only the code part is replayed
    if not (tool.result or tool.error):
        return [code_part]
    outcome = Outcome(tool.error) if tool.error else Outcome.OUTCOME_OK
    return [
        code_part,
        Part.from_code_execution_result(outcome=outcome, output=tool.result),
    ]


def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
    """Map a google-genai finish reason onto an Inspect stop reason."""
    if finish_reason == FinishReason.STOP:
        return "stop"
    if finish_reason == FinishReason.MAX_TOKENS:
        return "max_tokens"
    if finish_reason in (
        FinishReason.SAFETY,
        FinishReason.RECITATION,
        FinishReason.BLOCKLIST,
        FinishReason.PROHIBITED_CONTENT,
        FinishReason.SPII,
    ):
        return "content_filter"
    # everything else maps to "unknown" to avoid adding another option to
    # StopReason (this includes FinishReason.MALFORMED_FUNCTION_CALL)
    return "unknown"


def parse_safety_settings(
    safety_settings: Any,
) -> list[SafetySettingDict]:
    """Parse user-supplied safety settings (dict or JSON string) to SDK dicts.

    Raises:
        ValueError: when the input is not a dict of str category -> str threshold.
    """
    # accept a JSON-encoded string as well as a dict
    if isinstance(safety_settings, str):
        safety_settings = json.loads(safety_settings)
    if not isinstance(safety_settings, dict):
        raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")

    parsed: list[SafetySettingDict] = []
    for category, threshold in safety_settings.items():
        if not isinstance(category, str):
            raise ValueError(f"Unexpected type for harm category: {category}")
        if not isinstance(threshold, str):
            raise ValueError(f"Unexpected type for harm block threshold: {threshold}")
        parsed.append(
            {
                "category": str_to_harm_category(category),
                "threshold": str_to_harm_block_threshold(threshold),
            }
        )
    return parsed


def str_to_harm_category(category: str) -> HarmCategory:
    """Resolve a category string (short or full form) to a HarmCategory.

    Substring matching allows either the short form (e.g. "HARASSMENT") or
    the long form (e.g. "HARM_CATEGORY_HARASSMENT").
    """
    normalized = category.upper()
    # order matters: more specific needles are checked before generic ones
    matches: tuple[tuple[str, HarmCategory], ...] = (
        ("CIVIC_INTEGRITY", HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY),
        ("DANGEROUS_CONTENT", HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT),
        ("HATE_SPEECH", HarmCategory.HARM_CATEGORY_HATE_SPEECH),
        ("HARASSMENT", HarmCategory.HARM_CATEGORY_HARASSMENT),
        ("SEXUALLY_EXPLICIT", HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT),
        ("UNSPECIFIED", HarmCategory.HARM_CATEGORY_UNSPECIFIED),
    )
    for needle, harm_category in matches:
        if needle in normalized:
            return harm_category
    raise ValueError(f"Unknown HarmCategory: {normalized}")


def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
    """Resolve a threshold string to a HarmBlockThreshold via substring match."""
    normalized = threshold.upper()
    # checked in this specific order to preserve existing matching behavior
    matches: tuple[tuple[str, HarmBlockThreshold], ...] = (
        ("LOW", HarmBlockThreshold.BLOCK_LOW_AND_ABOVE),
        ("MEDIUM", HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
        ("HIGH", HarmBlockThreshold.BLOCK_ONLY_HIGH),
        ("NONE", HarmBlockThreshold.BLOCK_NONE),
        ("OFF", HarmBlockThreshold.OFF),
    )
    for needle, block_threshold in matches:
        if needle in normalized:
            return block_threshold
    raise ValueError(f"Unknown HarmBlockThreshold: {normalized}")


async def file_for_content(
    client: Client, content: ContentAudio | ContentVideo | ContentDocument
) -> File:
    """Upload media via the Google Files API, caching uploads by content hash.

    Uploads are keyed by the sha256 of the file bytes in a local kvstore so
    identical media is not re-uploaded; a cached entry is reused only while
    the remote file is still in the ACTIVE state. After a fresh upload, the
    function polls until the file leaves the PROCESSING state.

    Args:
        client: google-genai client used for the Files API calls.
        content: Audio, video, or document content to upload.

    Returns:
        The uploaded (or cache-served) File.

    Raises:
        ValueError: if the upload ends in the FAILED state.
    """
    # helper to write trace messages
    def trace(message: str) -> None:
        trace_message(logger, "Google Files", message)

    # get the file bytes and compute sha256 hash
    if isinstance(content, ContentAudio):
        file = content.audio
    elif isinstance(content, ContentVideo):
        file = content.video
    else:
        file = content.document
    content_bytes, mime_type = await file_as_data(file)
    content_sha256 = hashlib.sha256(content_bytes).hexdigest()
    # we cache uploads for re-use, open the db where we track that
    # (track up to 1 million previous uploads)
    with inspect_kvstore("google_files", 1000000) as files_db:
        # can we serve from existing uploads?
        uploaded_file = files_db.get(content_sha256)
        if uploaded_file:
            try:
                # NOTE(review): client.files.get/upload are synchronous calls
                # inside an async function -- presumably acceptable for this
                # provider; confirm if event-loop blocking becomes a concern
                upload: File = client.files.get(name=uploaded_file)
                assert upload.state
                if upload.state.name == "ACTIVE":
                    trace(f"Using uploaded file: {uploaded_file}")
                    return upload
                else:
                    # remote file expired or otherwise unusable; fall through
                    # to a fresh upload
                    trace(
                        f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
                    )
            except Exception as ex:
                # stale cache entry -- drop it and re-upload
                trace(f"Error attempting to access uploaded file: {ex}")
                files_db.delete(content_sha256)
        # do the upload (and record it)
        upload = client.files.upload(
            file=BytesIO(content_bytes), config=dict(mime_type=mime_type)
        )
        # poll until processing completes (3 second interval)
        while upload.state.name == "PROCESSING":  # type: ignore[union-attr]
            await anyio.sleep(3)
            assert upload.name
            upload = client.files.get(name=upload.name)
        if upload.state.name == "FAILED":  # type: ignore[union-attr]
            trace(f"Failed to upload file '{upload.name}: {upload.error}")
            raise ValueError(f"Google file upload failed: {upload.error}")
        # trace and record it
        trace(f"Uploaded file: {upload.name}")
        files_db.put(content_sha256, str(upload.name))
        # return the file
        return upload
